import climate_learn as cl
import numpy as np
import matplotlib.pyplot as plt
import os

# 1. 下载数据（已经跑通了）
cl.data.download_weatherbench(
    dst="./temperature_850",
    dataset="era5",
    variable="temperature_850"
)
# 运行目录需要放在根目录
print("Data downloaded successfully.")

# 2. 数据目录
data_dir = "./temperature_850"

# 3. WeatherBench数据下载后会有npz文件，读取npz中的温度数据
npz_files = [f for f in os.listdir(data_dir) if f.endswith('.npz')]
npz_files.sort()

# 4. 读取第一个 npz 文件数据
data = np.load(os.path.join(data_dir, npz_files[0]))

# WeatherBench标准中，温度变量名通常是 'temp' 或 'temperature'
print(data.files)  # 打印所有键，确认变量名

temp = data['temp']  # shape 一般是 (time, lat, lon)

# 5. 输出前4帧温度图像
for i in range(min(4, temp.shape[0])):
    plt.figure(figsize=(8, 4))
    plt.imshow(temp[i], cmap='coolwarm')
    plt.colorbar(label='Temperature (K)')
    plt.title(f'Frame {i}')
    plt.tight_layout()
    plt.show()
